import pandas as pd
from datasets import load_dataset, Dataset
from transformers import pipeline, logging
from tqdm import tqdm
import os
import re
import torch
import warnings
# Suppress UserWarning and FutureWarning messages globally for this script.
for _category in (UserWarning, FutureWarning):
    warnings.filterwarnings("ignore", category=_category)
del _category
# Opt in to pandas' future behaviour: no silent dtype downcasting in replace/fillna etc.
pd.set_option('future.no_silent_downcasting', True)
# Disable pandas' chained-assignment (SettingWithCopy) warnings.
pd.set_option('mode.chained_assignment', None)
# Only show transformers log messages at ERROR level or above.
logging.set_verbosity_error()

###########################
## Set device from CLI args
###########################
import os
import argparse
def get_device(argv=None):
    """Return the requested compute device name as a string.

    Resolution order: the ``--device`` command-line flag, then the ``DEVICE``
    environment variable, then ``"cuda"``. An empty ``DEVICE`` variable is
    treated as unset. The value is not validated here; the caller decides how
    to handle unknown names or unavailable hardware.

    Parameters
    - argv: optional list of argument strings to parse instead of
            ``sys.argv[1:]``. Unknown arguments are ignored.

    Returns
    - str: the device name, e.g. "cpu", "mps" or "cuda".
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", default=None, help="Device to use: cpu, mps, cuda")
    # ignore unknown args so the script still works when run with extra flags
    args, _ = parser.parse_known_args(argv)
    # Priority: CLI argument > non-empty DEVICE environment variable > default "cuda"
    return args.device or os.getenv("DEVICE") or "cuda"
DEVICE = get_device()
print(f"[INFO] Using DEVICE={DEVICE}")
# PyTorch setup
import torch
# Accept device names case-insensitively ("CUDA" == "cuda").
_requested = DEVICE.strip().lower()
if _requested == "cpu":
    device = torch.device("cpu")
elif _requested == "mps":
    device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
elif _requested == "cuda":
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
else:
    raise ValueError(f"Unknown device {DEVICE}")
# Make the CPU fallback for an unavailable accelerator visible instead of silent.
# (print rather than warnings.warn: UserWarning is filtered out at the top of the file.)
if device.type != _requested:
    print(f"[WARN] Requested device '{DEVICE}' is not available; falling back to {device}")
print(f"[INFO] Torch device set to {device}")

###########################
## Data, functions, etc.
###########################
test = pd.read_csv('./data/polnli_test_results.csv')
alts = pd.read_csv('./data/hypothesis_variants.csv')

# Attach the hypothesis variants: each test row is matched to the row of
# `alts` whose 'original' text equals the test row's 'hypothesis'.
test = test.merge(alts, how='left', left_on='hypothesis', right_on='original')

# After the left merge, rows without a matching variant have 'original' missing;
# keep only rows that do have variants, renumbered 0..n-1.
test = test.dropna(subset=['original']).reset_index(drop=True)

# Remove columns not needed downstream (presumably other models' predictions,
# the merge key 'original' and the unused 'alt4' variant).
test = test.drop(columns=[
    'augmented_hypothesis',
    'base_nli', 'large_nli',
    'base_debate', 'large_debate',
    'base_modern', 'large_modern',
    'llama', 'sonnet',
    'original', 'alt4',
])

# Checkpoint directories are named like "checkpoint-500"; group(1) is the step number.
_CHECKPOINT_RE = re.compile(r'checkpoint-(\d+)')

# Textual pipeline labels -> integer codes stored in the *_label columns.
_LABEL_TO_INT = {'entailment': 0, 'not_entailment': 1}

def _list_checkpoints(model_folder):
    """Return [(step, path), ...] for every checkpoint-<step> subdirectory, sorted by step.

    The step is parsed from the directory name itself, not the full path, so a
    parent folder whose name also contains "checkpoint-<n>" cannot corrupt the
    ordering or the reported checkpoint number.
    """
    found = []
    for name in os.listdir(model_folder):
        path = os.path.join(model_folder, name)
        match = _CHECKPOINT_RE.match(name)
        if match and os.path.isdir(path):
            found.append((int(match.group(1)), path))
    found.sort(key=lambda item: item[0])
    return found

def _to_label_codes(raw_labels, index):
    """Map pipeline label strings to ints via _LABEL_TO_INT and return an int Series on `index`.

    Raises ValueError naming the offending labels if any value cannot be cast to
    int after mapping (e.g. a checkpoint whose config uses "LABEL_0"/"LABEL_1").
    """
    mapped = pd.Series(raw_labels, index=index).replace(_LABEL_TO_INT)
    try:
        return mapped.astype(int)
    except (TypeError, ValueError) as err:
        unknown = sorted({str(label) for label in raw_labels} - set(_LABEL_TO_INT))
        raise ValueError(
            f"Unexpected pipeline labels {unknown}; expected {sorted(_LABEL_TO_INT)}"
        ) from err

def stability_benchmark(model_folder, test_df, batch_size=16, text_cols=None, device=None):
    """
    Run every checkpoint in `model_folder` over the premise/hypothesis-variant
    pairs in `test_df` and return the predicted labels for all checkpoints.

    Parameters:
    - model_folder: str, folder containing checkpoint-<step> subdirectories.
    - test_df: pd.DataFrame with a 'premise' column and one text column per
               entry of `text_cols`. It is not modified.
    - batch_size: int, batch size passed to the transformers pipeline.
    - text_cols: list of str or None, hypothesis columns to classify against
                 'premise'. Defaults to ['hypothesis', 'alt1', 'alt2', 'alt3'].
    - device: device passed to the pipeline (e.g. a torch.device). If None, the
              module-level `device` (selected from --device / DEVICE) is used
              when it is defined; otherwise the pipeline's default placement applies.

    Returns:
    - labels_df: pd.DataFrame with columns
        ['doc_index', 'Checkpoint'] + [c + '_label' for c in text_cols]
      (by default hypothesis_label, alt1_label, alt2_label, alt3_label),
      one row per document x checkpoint, sorted by Checkpoint then doc_index.
      'doc_index' is the original index of `test_df`; labels are 0 for
      'entailment' and 1 for 'not_entailment'.

    Raises:
    - ValueError if `model_folder` has no checkpoint-* directories, or if the
      pipeline returns a label that cannot be mapped to an int.
    """
    text_cols = ['hypothesis', 'alt1', 'alt2', 'alt3'] if text_cols is None else list(text_cols)
    if device is None:
        # Without this, the device chosen on the command line never reaches the pipeline.
        device = globals().get('device')

    checkpoints = _list_checkpoints(model_folder)
    if not checkpoints:
        raise ValueError(f"No checkpoint-* directories found in {model_folder!r}")

    # Keep the original index as 'doc_index' (works whether or not the index is named).
    base = test_df.rename_axis('doc_index').reset_index()
    # Pipeline inputs do not depend on the checkpoint, so build them once.
    inputs = {
        col: [{'text': premise, 'text_pair': hyp} for premise, hyp in zip(base['premise'], base[col])]
        for col in text_cols
    }

    all_checkpoint_dfs = []
    for step, model_path in tqdm(checkpoints, desc="Processing models"):
        pipe = pipeline('text-classification', model=model_path, batch_size=batch_size,
                        torch_dtype=torch.bfloat16, device=device)

        labels = pd.DataFrame({'doc_index': base['doc_index'], 'Checkpoint': step})
        for col in tqdm(text_cols, desc=f'Classifying columns in {model_path}', leave=False):
            res = pipe(inputs[col], truncation=True)
            labels[col + '_label'] = _to_label_codes([r['label'] for r in res], base.index)
        all_checkpoint_dfs.append(labels)

        # Release this checkpoint's model before the next one is loaded.
        del pipe
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        print(f"Checkpoint {step} completed")

    labels_df = pd.concat(all_checkpoint_dfs, ignore_index=True, sort=False)
    # sort by checkpoint then doc_index for ergonomics
    labels_df.sort_values(['Checkpoint', 'doc_index'], inplace=True, ignore_index=True)

    return labels_df

def _count_differing_pairs(frame, label_cols):
    """Return how many (row, column-pair) combinations in `frame` hold different labels.

    Vectorized per column pair. Two missing values count as equal, and a missing
    value vs. a present one counts as different. This matches grouping each row
    with value_counts(dropna=False).
    """
    total = 0
    for i, left_col in enumerate(label_cols):
        left = frame[left_col]
        for right_col in label_cols[i + 1:]:
            right = frame[right_col]
            # fillna(False) turns pandas NA comparison results (nullable dtypes) into "not equal"
            same = (left == right).fillna(False) | (left.isna() & right.isna())
            total += int((~same).sum())
    return total

def compute_stability(labels_df, label_cols=None, group_by=None):
    """
    Compute the per-group label disagreement rate ("Stability") across prompt variants.

    In every row, all C(k, 2) pairs of the k label columns are compared. A group's
    score is (number of differing pairs) / (C(k, 2) * rows in group). A score of
    0.0 means every variant received the same label in every row; larger values
    mean less stable predictions.

    Parameters
    - labels_df: pd.DataFrame as produced by `stability_benchmark`, optionally
                 with extra grouping columns (e.g. 'Model').
    - label_cols: list of str, label columns to compare. If None, defaults to
                  ['hypothesis_label', 'alt1_label', 'alt2_label', 'alt3_label'].
    - group_by: str, list/tuple of str, or None. Column(s) to group by. None means
                ['Checkpoint']; an empty list/tuple computes one score over the
                whole frame.

    Returns
    - pd.DataFrame with columns group_by + ['Stability'], one row per group,
      sorted by group_by. A group with no comparisons (fewer than two label
      columns) gets NaN. An empty labels_df yields an empty frame with these columns.

    Raises
    - TypeError if group_by is not None, a string, or a list/tuple.
    - ValueError if a group_by column is missing from labels_df.
    """
    if label_cols is None:
        label_cols = ['hypothesis_label', 'alt1_label', 'alt2_label', 'alt3_label']
    label_cols = list(label_cols)

    # normalize group_by to a list
    if group_by is None:
        group_by = ['Checkpoint']
    elif isinstance(group_by, str):
        group_by = [group_by]
    elif isinstance(group_by, (list, tuple)):
        group_by = list(group_by)
    else:
        raise TypeError("group_by must be None, a string, or a list/tuple of strings")

    # verify group_by columns exist
    missing = [c for c in group_by if c not in labels_df.columns]
    if missing:
        raise ValueError(f"The following group_by columns are not present in labels_df: {missing}")

    n_prompts = len(label_cols)
    pairs_per_row = (n_prompts * (n_prompts - 1)) // 2
    result_cols = group_by + ['Stability']

    def stability_of(frame):
        total_comparisons = pairs_per_row * len(frame)
        if total_comparisons == 0:
            return float('nan')
        return _count_differing_pairs(frame, label_cols) / total_comparisons

    if not group_by:
        # No grouping requested: a single score over the whole frame.
        return pd.DataFrame([{'Stability': stability_of(labels_df)}], columns=result_cols)

    results = []
    for key, group in labels_df.groupby(group_by, sort=True):
        # pandas yields tuple keys for list groupers (always in pandas >= 2); normalize anyway
        key = key if isinstance(key, tuple) else (key,)
        row = dict(zip(group_by, key))
        row['Stability'] = stability_of(group)
        results.append(row)

    # Explicit columns keep the schema even when there are no groups (empty input).
    stability_df = pd.DataFrame(results, columns=result_cols)
    stability_df.sort_values(group_by, inplace=True, ignore_index=True)
    return stability_df

def _run_benchmark(model_folder, model_name, batch_size):
    """Run `stability_benchmark` on one model's checkpoint folder over `test`; tag rows with `model_name`."""
    labels = stability_benchmark(model_folder, test_df=test, batch_size=batch_size)
    labels['Model'] = model_name
    return labels

####################
## Modern Bert Base
####################
mbb = _run_benchmark("training_ModernBase", 'Modern BERT Base', batch_size=128)

####################
## Modern Bert Large
####################
mbl = _run_benchmark("training_ModernLarge", 'Modern BERT Large', batch_size=128)

####################
## DeBERTa Base
####################
db = _run_benchmark("training_base", 'DeBERTa Base', batch_size=8)

####################
## DeBERTa Large
####################
dl = _run_benchmark("training_large2", 'DeBERTa Large', batch_size=8)

###################
## Compile Results
###################
# ignore_index gives the combined frame a unique 0..n-1 index instead of
# four overlapping per-model indexes.
all_labels = pd.concat([mbb, mbl, db, dl], ignore_index=True)
all_labels.to_csv('./data/stability_labels.csv', index=False)